{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 5.3.1 RNN层的实现"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-09T04:24:47.820046Z",
     "start_time": "2023-05-09T04:24:47.684046700Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "class RNN:\n",
    "    def __init__(self, Wx, Wh, b):\n",
    "        self.params = [Wx, Wh, b]\n",
    "        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]\n",
    "        self.cache = None\n",
    "\n",
    "    def forword(self, x, h_prev):\n",
    "        Wx, Wh, b = self.params\n",
    "        t = np.dot(h_prev, Wh) + np.dot(x, Wx) + b\n",
    "        h_next = np.tanh(t)\n",
    "\n",
    "        self.cache = (x, h_prev, h_next)\n",
    "        return h_next\n",
    "\n",
    "    def backward(self, dh_next):\n",
    "        Wx, Wh, b = self.params\n",
    "        x, h_prev, h_next = self.cache\n",
    "\n",
    "        dt = dh_next * (1 - h_next ** 2)\n",
    "        db = np.sum(dt, axis=0)\n",
    "        dWh = np.dot(h_prev.T, dt)\n",
    "        dh_prev = np.dot(x.T, dt)\n",
    "        dWx = np.dot(x.T, dt)\n",
    "        dx = np.dot(dt, Wx.T)\n",
    "\n",
    "        self.grads[0][...] = dWx\n",
    "        self.grads[1][...] = dWh\n",
    "        self.grads[2][...] = db\n",
    "\n",
    "        return dx, dh_prev\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "class TimeRNN:\n",
    "    def __init__(self, Wx, Wh, b, stateful=False):\n",
    "        self.params = [Wx, Wh, b]\n",
    "        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]\n",
    "        self.layers = None\n",
    "\n",
    "        self.h, self.dh = None, None\n",
    "        self.stateful = stateful\n",
    "\n",
    "    def set_state(self, h):\n",
    "        self.h = h\n",
    "\n",
    "    def reset_state(self):\n",
    "        self.h = None\n",
    "\n",
    "    def forward(self, xs):\n",
    "        Wx, Wh, b = self.params\n",
    "        N, T, D = xs.shape\n",
    "        D, H = Wx.shape\n",
    "\n",
    "        self.layers = []\n",
    "        hs = np.empty((N, T, H), dtype='f')\n",
    "\n",
    "        if not self.stateful or self.h is None:\n",
    "            self.h = np.zeros((N, H), dtype='f')\n",
    "\n",
    "        for t in range(T):\n",
    "            layer = RNN(*self.params)\n",
    "            self.h = layer.forword(xs[:, t, :], self.h)\n",
    "            hs[:, t, :] = self.h\n",
    "            self.layers.append(layer)\n",
    "\n",
    "        return hs\n",
    "\n",
    "    def backward(self, dhs):\n",
    "        Wx, Wh, b = self.params\n",
    "        N, T, H = dhs.shape\n",
    "        D, H = Wx.shape\n",
    "\n",
    "        dxs = np.empty((N, T, D), dtype='f')\n",
    "        dh = 0\n",
    "        grads = [0, 0, 0]\n",
    "        for t in reversed(range(T)):\n",
    "            layer = self.layers[t]\n",
    "            dx, dh = layer.backward(dhs[:, t, :] + dh)  # 求和后的梯度\n",
    "            dxs[:, t, :] = dx\n",
    "\n",
    "            for i, grad in enumerate(layer.grads):\n",
    "                grads[i] += grad\n",
    "\n",
    "        for i, grad in enumerate(grads):\n",
    "            self.grads[i][...] = grad\n",
    "        self.dh = dh\n",
    "\n",
    "        return dxs"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:00:29.795047400Z",
     "start_time": "2023-05-09T05:00:29.782046700Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
